'''
f-IRL: Extract policy/reward from specified expert samples
'''
import sys, os, time
sys.path.append(
    os.path.dirname(os.path.dirname(__file__))
)


import numpy as np
import torch
import gym
from ruamel.yaml import YAML

from firl.divs.f_div_disc import f_div_disc_loss
from firl.divs.f_div import maxentirl_loss
from firl.divs.ipm import ipm_loss
from firl.models.reward import MLPReward
from firl.models.discrim import SMMIRLDisc as Disc
from firl.models.discrim import SMMIRLCritic as Critic
from common.sac import ReplayBuffer, SAC

import envs
from utils import system, collect, logger, eval
from utils.plots.train_plot_high_dim import plot_disc
from utils.plots.train_plot import plot_disc as visual_disc

import datetime
import dateutil.tz
import json, copy

def try_evaluate(itr: int, policy_type: str, sac_info):
    """Evaluate the current antagonist policy and log the results.

    This helper reads module-level state created in the ``__main__`` section
    (``v``, ``antagonist_samples``, ``expert_samples``, ``antagonist_sac_agent``,
    ``env_fn``, ``env_name``, ``reward_func``, ``log_folder``, ``gym_env``, plus
    ``critic_loss`` for the ``emd`` objective and ``disc``/``disc_loss`` for the
    PointMaze plot), so it can only be called after that setup has run.

    Args:
        itr: current IRL iteration; used to derive the logged update count and
            environment-step count.
        policy_type: prefix for logged keys; only ``"Running"`` is accepted.
        sac_info: SAC training info forwarded to the PointMaze visualisation.

    Returns:
        ``(real_return_det, real_return_sto)``: the values returned by
        ``eval.evaluate_real_return`` for the antagonist with the deterministic
        flag set to True and False respectively.
    """
    assert policy_type in ["Running"]
    horizon = v['env']['T']
    update_time = itr * v['reward']['gradient_step']
    env_steps = itr * v['sac']['epochs'] * horizon

    agent_emp_states = antagonist_samples[0].copy()
    assert agent_emp_states.shape[0] == v['irl']['training_trajs']
    state_dim = agent_emp_states.shape[2]
    metrics = eval.KL_summary(expert_samples, agent_emp_states.reshape(-1, state_dim),
                              env_steps, policy_type)

    def rollout_return(deterministic):
        # A fresh environment instance is created for every evaluation run.
        return eval.evaluate_real_return(antagonist_sac_agent.get_action, env_fn(),
                                         v['irl']['eval_episodes'], horizon, deterministic)

    # eval real reward
    real_return_det = rollout_return(True)
    metrics['Real Det Return'] = real_return_det
    print(f"real det return avg: {real_return_det:.2f}")
    logger.record_tabular("Real Det Return", round(real_return_det, 2))

    real_return_sto = rollout_return(False)
    metrics['Real Sto Return'] = real_return_sto
    print(f"real sto return avg: {real_return_sto:.2f}")
    logger.record_tabular("Real Sto Return", round(real_return_sto, 2))

    if v['obj'] == "emd":
        main_losses = critic_loss["main"]
        # Negated mean of the last 10% of critic losses; when that 10% rounds
        # down to 0 the slice [-0:] covers the whole list.
        tail = int(0.1 * len(main_losses))
        emd = -np.array(main_losses[-tail:]).mean()
        metrics['emd'] = emd
        logger.record_tabular(f"{policy_type} EMD", emd)

    if "PointMaze" in env_name:
        visual_disc(agent_emp_states, reward_func.get_scalar_reward, disc.log_density_ratio, v['obj'],
                    log_folder, env_steps, gym_env.range_lim,
                    sac_info, disc_loss, metrics)

    logger.record_tabular(f"{policy_type} Update Time", update_time)
    logger.record_tabular(f"{policy_type} Env Steps", env_steps)

    return real_return_det, real_return_sto

def clip_grad_norm(parameters, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters, in place.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector, and every gradient is multiplied by
    ``min(1, max_norm / (total_norm + 1e-6))``.

    Non-finite gradient entries are handled before the norm is computed. If
    ``error_if_nonfinite`` is True a ``RuntimeError`` is raised. Otherwise
    ``nan`` entries are replaced by 0 and ``+inf``/``-inf`` by
    ``+max_norm``/``-max_norm``, so the subsequent clipping yields finite
    gradients.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int or str): type of the used p-norm. Can be
            ``'inf'`` (or ``float('inf')``) for infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if any gradient
            from :attr:`parameters` contains ``nan``, ``inf``, or ``-inf``.
            Default: False (non-finite entries are sanitized instead).

    Returns:
        Total norm of the (sanitized) parameter gradients, viewed as a single
        vector, measured before clipping.

    Raises:
        RuntimeError: if ``error_if_nonfinite`` is True and a gradient is
            non-finite.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    # float() also accepts the string 'inf', so both spellings work below.
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    device = grads[0].device

    if not all(torch.isfinite(g).all() for g in grads):
        if error_if_nonfinite:
            raise RuntimeError(
                f"The total norm of order {norm_type} for gradients is non-finite, "
                "so it cannot be clipped. Set error_if_nonfinite=False to sanitize instead.")
        print("Clip Infinite grads")
        # Replace non-finite entries in place so the norm below is finite.
        for g in grads:
            g.detach().nan_to_num_(nan=0.0, posinf=max_norm, neginf=-max_norm)

    if norm_type == float("inf"):
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]),
            norm_type)

    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm
    
def clip_grad_value(parameters, clip_value: float) -> None:
    r"""Clips gradient of an iterable of parameters at specified value.

    Gradients are modified in-place. Non-finite entries are sanitized first:
    ``nan`` becomes 0 and ``+inf``/``-inf`` become ``+clip_value``/``-clip_value``.
    Every entry is then clamped to the allowed range. Parameters whose
    ``.grad`` is ``None`` are skipped.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients clipped
        clip_value (float or int): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    for p in parameters:
        if p.grad is None:
            continue
        # nan_to_num_ guarantees finiteness; clamp_ enforces the documented range.
        p.grad.detach().nan_to_num_(nan=0.0, posinf=clip_value, neginf=-clip_value).clamp_(
            min=-clip_value, max=clip_value)


def get_pagar_loss(protagonist_samples, antagonist_samples, protagonist_agent, antagonist_agent, reward_func, device, clip_param = 0.2):
    """Compute the PAGAR loss term for the reward function from two policies' rollouts.

    Args:
        protagonist_samples, antagonist_samples: ``(states, actions, log_probs)``
            triples. ``states`` and ``actions`` are unpacked as 3-axis
            ``(N, T, dim)`` arrays; ``log_probs`` are flattened to ``(N*T, 1)``.
        protagonist_agent, antagonist_agent: objects exposing
            ``ac.log_prob(states, actions)``; each policy is scored on the
            *other* policy's samples.
        reward_func: object exposing ``r(states)``; its output is reshaped to
            ``(N*T, 1)`` and is the only term gradients flow through (all
            policy log-probs are detached).
        device: torch device on which every tensor is placed.
        clip_param: clipping range for the ratio-based terms 3 and 4.

    Returns:
        Scalar tensor ``pair_loss1 + pair_loss2``, plus ``pair_loss4`` and
        ``pair_loss3`` each only when it is finite.
    """
    protagonist_s, protagonist_a, protagonist_log_a = protagonist_samples
    antagonist_s, antagonist_a, antagonist_log_a = antagonist_samples

    protagonist_N, protagonist_T, protagonist_s_d = protagonist_s.shape
    antagonist_N, antagonist_T, antagonist_s_d = antagonist_s.shape

    _, _, protagonist_a_d = protagonist_a.shape
    _, _, antagonist_a_d = antagonist_a.shape

    assert protagonist_s_d == antagonist_s_d
    assert protagonist_a_d == antagonist_a_d

    protagonist_s_vec = protagonist_s.reshape(-1, protagonist_s_d)
    antagonist_s_vec = antagonist_s.reshape(-1, antagonist_s_d)

    protagonist_a_vec = protagonist_a.reshape(-1, protagonist_a_d)
    antagonist_a_vec = antagonist_a.reshape(-1, antagonist_a_d)

    # The recorded log-probs must live on ``device``: they are combined below with
    # network outputs computed on ``device``. torch.tensor() left them on the CPU,
    # which raises a device-mismatch error when running on CUDA.
    protagonist_log_a_vec = torch.as_tensor(protagonist_log_a.reshape(-1, 1), device=device)
    antagonist_log_a_vec = torch.as_tensor(antagonist_log_a.reshape(-1, 1), device=device)

    # Cross log-probs: each policy evaluated on the other policy's (s, a) pairs.
    antagonist_protagonist_log_a_vec = protagonist_agent.ac.log_prob(torch.FloatTensor(antagonist_s_vec).to(device), torch.FloatTensor(antagonist_a_vec).to(device)).view(antagonist_N * antagonist_T, 1).detach()
    protagonist_antagonist_log_a_vec = antagonist_agent.ac.log_prob(torch.FloatTensor(protagonist_s_vec).to(device), torch.FloatTensor(protagonist_a_vec).to(device)).view(protagonist_N * protagonist_T, 1).detach()

    protagonist_r_vec = reward_func.r(torch.FloatTensor(protagonist_s_vec).to(device)).view(protagonist_N * protagonist_T, 1)
    antagonist_r_vec = reward_func.r(torch.FloatTensor(antagonist_s_vec).to(device)).view(antagonist_N * antagonist_T, 1)

    # Term 1: protagonist samples, importance-weighted by antagonist/protagonist
    # probability ratio. The log can produce nan/inf, so only finite entries are averaged.
    pair_r1 = - ((protagonist_antagonist_log_a_vec.exp() / protagonist_r_vec - protagonist_antagonist_log_a_vec.exp()).log())
    pair_ratio1 = torch.exp(protagonist_antagonist_log_a_vec - protagonist_log_a_vec).detach()
    pair_loss1 = (pair_r1 * pair_ratio1)
    pair_loss1 = pair_loss1[torch.isfinite(pair_loss1)].mean()

    # Squared log-prob discrepancy between the two policies, used as a constant weight.
    pair_kl1 = torch.nn.functional.mse_loss(protagonist_log_a_vec, protagonist_antagonist_log_a_vec).detach().item()

    # NOTE(review): 4 * 0.99 / (1 - 0.99) looks like a 4*gamma/(1-gamma) bound with
    # gamma = 0.99 -- confirm. This term is ADDED here but SUBTRACTED for term 2 below.
    pair_loss1 =  pair_loss1 + pair_kl1 * 4 * 0.99 / (1 - 0.99) * torch.abs(pair_r1.flatten()).max()
    pair_loss1 = pair_loss1 - pair_r1[torch.isfinite(pair_r1)].mean()

    # Term 2: the symmetric construction on antagonist samples.
    pair_r2 = ((antagonist_log_a_vec.exp() / antagonist_r_vec - antagonist_log_a_vec.exp()).log())
    pair_ratio2 = (torch.exp(antagonist_protagonist_log_a_vec - antagonist_log_a_vec)).detach()
    pair_loss2 = (pair_r2 * pair_ratio2)
    pair_loss2 = pair_loss2[torch.isfinite(pair_loss2)].mean()

    pair_kl2 = torch.nn.functional.mse_loss(antagonist_log_a_vec, antagonist_protagonist_log_a_vec).detach().item()

    pair_loss2 = pair_loss2 - pair_kl2 * 4 * 0.99 / (1 - 0.99) * torch.abs(pair_r2.flatten()).max()
    pair_loss2 = pair_loss2 - pair_r2[torch.isfinite(pair_r2)].mean()

    # Term 3: clipped-ratio objective on pair_r2 (min of unclipped/clipped products).
    pair_ratio3 = (torch.exp(pair_r2 - antagonist_log_a_vec.detach()))
    pair_clipped_ratio3 = torch.clamp(pair_ratio3, 1 - clip_param, 1 + clip_param)
    pair_loss3 = - torch.min(pair_r2 * pair_ratio3, pair_r2 * pair_clipped_ratio3).mean()
    # NOTE(review): term 3 is built from pair_r2 but subtracts the mean of pair_r1 -- confirm intended.
    pair_loss3 = pair_loss3 - pair_r1[torch.isfinite(pair_r1)].mean()

    # Term 4: clipped-ratio objective on -pair_r1.
    pair_ratio4 = (torch.exp(-pair_r1 - protagonist_log_a_vec.detach()))
    pair_clipped_ratio4 = torch.clamp(pair_ratio4, 1 - clip_param, 1 + clip_param)
    pair_loss4 = -torch.min(-pair_r1 * pair_ratio4, -pair_r1 * pair_clipped_ratio4).mean()
    pair_loss4 = pair_loss4 - pair_r1[torch.isfinite(pair_r1)].mean()

    # Terms 3 and 4 are .mean() over all entries (no finite filtering), so they are
    # dropped entirely when non-finite.
    pair_loss = pair_loss1 + pair_loss2 + (pair_loss4 if torch.isfinite(pair_loss4).all() else 0.)  + (pair_loss3 if torch.isfinite(pair_loss3).all() else 0.)

    return pair_loss
    


if __name__ == "__main__":
    # Load the experiment configuration (YAML path given as the first CLI argument).
    yaml = YAML()
    with open(sys.argv[1]) as f:
        v = yaml.load(f)

    # common parameters
    env_name = v['env']['env_name']
    state_indices = v['env']['state_indices']
    seed = v['seed']
    num_expert_trajs = v['irl']['expert_episodes']

    # system: device, threads, seed, pid
    device = torch.device(f"cuda:{v['cuda']}" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)
    pid=os.getpid()

    # assumptions
    assert v['obj'] in ['fkl', 'rkl', 'js', 'emd', 'maxentirl']
    assert v['IS'] == False

    # logs
    exp_id = f"logs/{env_name}/exp-{num_expert_trajs}/pagar_{v['obj']}" # task/obj/date structure
    # exp_id = 'debug'
    if not os.path.exists(exp_id):
        os.makedirs(exp_id)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id + '/' + now.strftime('%Y_%m_%d_%H_%M_%S')
    logger.configure(dir=log_folder)
    print(f"Logging to directory: {log_folder}")
    # Snapshot the script and config next to the logs (best effort: os.system does not raise on failure).
    os.system(f'cp firl/pagar_samples.py {log_folder}')
    os.system(f'cp {sys.argv[1]} {log_folder}/variant_{pid}.yml')
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    print('pid', pid)
    os.makedirs(os.path.join(log_folder, 'plt'))
    os.makedirs(os.path.join(log_folder, 'model'))

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # load expert samples from trained policy
    expert_trajs = torch.load(f'expert_data/states/{env_name}.pt').numpy()[:, :, state_indices]
    expert_trajs = expert_trajs[:num_expert_trajs, :, :] # select first expert_episodes
    expert_samples = expert_trajs.copy().reshape(-1, len(state_indices))
    print(expert_trajs.shape, expert_samples.shape) # ignored starting state

    # Initilialize reward as a neural network
    reward_func = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    reward_optimizer = torch.optim.Adam(reward_func.parameters(), lr=v['reward']['lr'],
        weight_decay=v['reward']['weight_decay'], betas=(v['reward']['momentum'], 0.999))

    # Initilialize discriminator
    if v['obj'] in ["emd"]:
        critic = Critic(len(state_indices), **v['critic'], device=device)
    elif v['obj'] != 'maxentirl':
        disc = Disc(len(state_indices), **v['disc'], device=device)

    max_real_return_det, max_real_return_sto = -np.inf, -np.inf
    for itr in range(v['irl']['n_itrs']):

        if v['sac']['reinitialize'] or itr == 0:
            # Reset SAC agent with old policy, new environment, and new replay buffer
            print("Reinitializing sac")
            replay_buffer = ReplayBuffer(
                state_size,
                action_size,
                device=device,
                size=v['sac']['buffer_size'])

            # NOTE(review): both agents are constructed with the same replay_buffer object -- confirm intended.
            protagonist_sac_agent = SAC(env_fn, replay_buffer,
                steps_per_epoch=v['env']['T'],
                update_after=v['env']['T'] * v['sac']['random_explore_episodes'],
                max_ep_len=v['env']['T'],
                seed=seed,
                start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
                reward_state_indices=state_indices,
                device=device,
                **v['sac']
            )

            antagonist_sac_agent = SAC(env_fn, replay_buffer,
                steps_per_epoch=v['env']['T'],
                update_after=v['env']['T'] * v['sac']['random_explore_episodes'],
                max_ep_len=v['env']['T'],
                seed=seed,
                start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
                reward_state_indices=state_indices,
                device=device,
                **v['sac']
            )

        protagonist_sac_agent.reward_function =  antagonist_sac_agent.reward_function = reward_func.get_scalar_reward # only need to change reward in sac
        print("Protagonist:")
        protagonist_sac_info = protagonist_sac_agent.learn_mujoco(print_out=True)
        print("Antagonist:")
        antagonist_sac_info = antagonist_sac_agent.learn_mujoco(print_out=True)

        start = time.time()
        protagonist_samples = collect.collect_trajectories_policy_single(gym_env, protagonist_sac_agent,
                        n = v['irl']['training_trajs'], state_indices=state_indices)
        antagonist_samples = collect.collect_trajectories_policy_single(gym_env, antagonist_sac_agent,
                        n = v['irl']['training_trajs'], state_indices=state_indices)
        # Fit a density model using the samples
        protagonist_agent_emp_states = protagonist_samples[0].copy()
        protagonist_agent_emp_states = protagonist_agent_emp_states.reshape(-1,protagonist_agent_emp_states.shape[2]) # n*T states
        print(f'collect trajs {time.time() - start:.0f}s', flush=True)

        antagonist_agent_emp_states = antagonist_samples[0].copy()
        antagonist_agent_emp_states = antagonist_agent_emp_states.reshape(-1,antagonist_agent_emp_states.shape[2]) # n*T states
        print(f'collect trajs {time.time() - start:.0f}s', flush=True)

        # The discriminator / critic is trained on expert vs. antagonist states.
        start = time.time()
        if v['obj'] in ["emd"]:
            critic_loss = critic.learn(expert_samples.copy(), antagonist_agent_emp_states, iter=v['critic']['iter'])
        elif v['obj'] != 'maxentirl': # learn log density ratio
            disc_loss = disc.learn(expert_samples.copy(), antagonist_agent_emp_states, iter=v['disc']['iter'])
        print(f'train disc {time.time() - start:.0f}s', flush=True)

        # optimization w.r.t. reward
        reward_losses = []
        for _ in range(v['reward']['gradient_step']):
            if v['irl']['resample_episodes'] > v['irl']['expert_episodes']:
                expert_res_indices = np.random.choice(expert_trajs.shape[0], v['irl']['resample_episodes'], replace=True)
                expert_trajs_train = expert_trajs[expert_res_indices].copy() # resampling the expert trajectories
            elif v['irl']['resample_episodes'] > 0:
                expert_res_indices = np.random.choice(expert_trajs.shape[0], v['irl']['resample_episodes'], replace=False)
                expert_trajs_train = expert_trajs[expert_res_indices].copy()
            else:
                expert_trajs_train = None # not use expert trajs

            # All objectives use the antagonist rollouts, matching the disc/critic training above.
            # (The former `samples` name was never defined and raised NameError for maxentirl/emd.)
            if v['obj'] in ['fkl', 'rkl', 'js']:
                loss, _ = f_div_disc_loss(v['obj'], v['IS'], antagonist_samples, disc, reward_func, device, expert_trajs=expert_trajs_train)
            elif v['obj'] == 'maxentirl':
                loss = maxentirl_loss(v['obj'], antagonist_samples, expert_samples, reward_func, device)
            elif v['obj'] == 'emd':
                loss, _ = ipm_loss(v['obj'], v['IS'], antagonist_samples, critic.value, reward_func, device, expert_trajs=expert_trajs_train)

            pagar_loss = get_pagar_loss(protagonist_samples, antagonist_samples, protagonist_sac_agent, antagonist_sac_agent, reward_func, device)
            tot_loss = loss + pagar_loss * 1e-3

            reward_losses.append(tot_loss.item())
            print(f"{v['obj']}_loss: {loss}, pagar_{v['obj']}_loss: {pagar_loss}, total_loss: {tot_loss}")
            reward_optimizer.zero_grad()
            # Backpropagate the combined objective; backpropagating only `loss`
            # silently discarded the PAGAR term.
            tot_loss.backward()
            clip_grad_value(reward_func.parameters(), 1)
            reward_optimizer.step()

        # evaluating the learned reward
        real_return_det, real_return_sto = try_evaluate(itr, "Running", antagonist_sac_info)
        # Checkpoint only when both deterministic and stochastic returns improve.
        if real_return_det > max_real_return_det and real_return_sto > max_real_return_sto:
            max_real_return_det, max_real_return_sto = real_return_det, real_return_sto
            torch.save(reward_func.state_dict(), os.path.join(logger.get_dir(),
                    f"model/reward_model_itr{itr}_det{max_real_return_det:.0f}_sto{max_real_return_sto:.0f}.pkl"))

        logger.record_tabular("Itration", itr)
        logger.record_tabular("Reward Loss", loss.item())
        logger.record_tabular("PAGAR Loss", pagar_loss.item())
        if v['sac']['automatic_alpha_tuning']:
            logger.record_tabular("protagonist_alpha", protagonist_sac_agent.alpha.item())
            logger.record_tabular("antagonist_alpha", antagonist_sac_agent.alpha.item())

        # if v['irl']['save_interval'] > 0 and (itr % v['irl']['save_interval'] == 0 or itr == v['irl']['n_itrs']-1):
        #     torch.save(reward_func.state_dict(), os.path.join(logger.get_dir(), f"model/reward_model_{itr}.pkl"))

        logger.dump_tabular()